{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Running membership inference attacks on a Regression Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this tutorial we will show how to run black-box membership attacks. This will be demonstrated on the Diabetes dataset from Scikitlearn (https://scikit-learn.org/stable/datasets/toy_dataset.html#diabetes-dataset). "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import os\n",
    "import sys\n",
    "sys.path.insert(0, os.path.abspath('..'))\n",
    "\n",
    "from art.utils import load_diabetes\n",
    "\n",
    "(x_train, y_train), (x_test, y_test), _, _ = load_diabetes(test_set=0.5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train MLP regression model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Base model score:  0.5182246496026136\n"
     ]
    }
   ],
   "source": [
    "from sklearn.linear_model import LinearRegression\n",
    "from art.estimators.regression.scikitlearn import ScikitlearnRegressor\n",
    "\n",
    "model = LinearRegression()\n",
    "model.fit(x_train, y_train)\n",
    "\n",
    "art_regressor = ScikitlearnRegressor(model)\n",
    "\n",
    "print('Base model score: ', model.score(x_test, y_test))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Attack\n",
    "### Black-box attack\n",
    "The black-box attack basically trains an additional classifier (called the attack model) to predict the membership status of a sample. The loss of the target model is used as input to the learning process."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Member accuracy: 0.7239819004524887\n",
      "Non-Member accuracy: 0.7963800904977376\n",
      "Accuracy: 0.7601809954751131\n"
     ]
    }
   ],
   "source": [
    "from art.attacks.inference.membership_inference import MembershipInferenceBlackBox\n",
    "\n",
    "\n",
    "bb_attack = MembershipInferenceBlackBox(art_regressor, attack_model_type='rf', input_type='loss')\n",
    "\n",
    "attack_train_ratio = 0.5\n",
    "attack_train_size = int(len(x_train) * attack_train_ratio)\n",
    "attack_test_size = int(len(x_test) * attack_train_ratio)\n",
    "\n",
    "\n",
    "# train attack model\n",
    "bb_attack.fit(x_train[:attack_train_size], y_train[:attack_train_size],\n",
    "              x_test[:attack_test_size], y_test[:attack_test_size])\n",
    "\n",
    "# infer \n",
    "inferred_train_bb = bb_attack.infer(x_train.astype(np.float32), y_train)\n",
    "inferred_test_bb = bb_attack.infer(x_test.astype(np.float32), y_test)\n",
    "\n",
    "# check accuracy\n",
    "train_acc_bb = np.sum(inferred_train_bb) / len(inferred_train_bb)\n",
    "test_acc_bb = 1 - (np.sum(inferred_test_bb) / len(inferred_test_bb))\n",
    "acc_bb = (train_acc_bb * len(inferred_train_bb) + test_acc_bb * len(inferred_test_bb)) / (len(inferred_train_bb) + len(inferred_test_bb))\n",
    "print('Member accuracy:', train_acc_bb)\n",
    "print('Non-Member accuracy:', test_acc_bb)\n",
    "print('Accuracy:', acc_bb)"
   ]
  },
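  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough check on how the three numbers printed above relate, the arithmetic can be written out. The counts below are not printed by the notebook; they are inferred from the reported accuracies, assuming the 442-sample Diabetes dataset was split into 221 members and 221 non-members. $TP$ is the number of members inferred as members and $TN$ the number of non-members inferred as non-members.\n",
    "\n",
    "$$\n",
    "\\text{member accuracy} = \\frac{TP}{N_{\\text{member}}} = \\frac{160}{221} \\approx 0.724\n",
    "$$\n",
    "\n",
    "$$\n",
    "\\text{non-member accuracy} = \\frac{TN}{N_{\\text{non-member}}} = \\frac{176}{221} \\approx 0.796\n",
    "$$\n",
    "\n",
    "$$\n",
    "\\text{accuracy} = \\frac{TP + TN}{N_{\\text{member}} + N_{\\text{non-member}}} = \\frac{160 + 176}{442} \\approx 0.760\n",
    "$$\n",
    "\n",
    "Because the two sets are the same size under this assumption, the overall accuracy is simply the mean of the two per-class accuracies."
   ]
  },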
  {
   "cell_type": "code",
   "execution_count": 77,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(0.7804878048780488, 0.7239819004524887)\n"
     ]
    }
   ],
   "source": [
    "def calc_precision_recall(predicted, actual, positive_value=1):\n",
    "    score = 0  # both predicted and actual are positive\n",
    "    num_positive_predicted = 0  # predicted positive\n",
    "    num_positive_actual = 0  # actual positive\n",
    "    for i in range(len(predicted)):\n",
    "        if predicted[i] == positive_value:\n",
    "            num_positive_predicted += 1\n",
    "        if actual[i] == positive_value:\n",
    "            num_positive_actual += 1\n",
    "        if predicted[i] == actual[i]:\n",
    "            if predicted[i] == positive_value:\n",
    "                score += 1\n",
    "    \n",
    "    if num_positive_predicted == 0:\n",
    "        precision = 1\n",
    "    else:\n",
    "        precision = score / num_positive_predicted  # the fraction of predicted “Yes” responses that are correct\n",
    "    if num_positive_actual == 0:\n",
    "        recall = 1\n",
    "    else:\n",
    "        recall = score / num_positive_actual  # the fraction of “Yes” responses that are predicted correctly\n",
    "\n",
    "    return precision, recall\n",
    "\n",
    "# rule-based\n",
    "print(calc_precision_recall(np.concatenate((inferred_train_bb, inferred_test_bb)), \n",
    "                            np.concatenate((np.ones(len(inferred_train_bb)), np.zeros(len(inferred_test_bb))))))"
   ]
  }
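  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The precision and recall printed above can be reproduced by hand as a sanity check. The counts are not printed by the notebook; they are inferred from the reported values, assuming 221 members and 221 non-members. $TP$ is the number of members inferred as members, $FP$ the number of non-members inferred as members, and $FN$ the number of members inferred as non-members.\n",
    "\n",
    "$$\n",
    "\\text{precision} = \\frac{TP}{TP + FP} = \\frac{160}{160 + 45} = \\frac{160}{205} \\approx 0.780\n",
    "$$\n",
    "\n",
    "$$\n",
    "\\text{recall} = \\frac{TP}{TP + FN} = \\frac{160}{160 + 61} = \\frac{160}{221} \\approx 0.724\n",
    "$$\n",
    "\n",
    "Recall coincides with the member accuracy reported by the attack cell, since both divide the correctly inferred members by the total number of members."
   ]
  }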
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
